{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.4 自定义层\n",
    "## 4.4.1 不含模型参数的自定义层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class CenteredLayer(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(CenteredLayer, self).__init__(**kwargs)\n",
    "    def forward(self, x):\n",
    "        return x - x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = net(torch.rand(4, 8))\n",
    "y.mean().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.4.2 含模型参数的自定义层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyListDense(\n",
      "  (params): ParameterList(\n",
      "      (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MyListDense(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyListDense, self).__init__()\n",
    "        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])\n",
    "        self.params.append(nn.Parameter(torch.randn(4, 1)))\n",
    "\n",
    "    def forward(self, x):\n",
    "        for i in range(len(self.params)):\n",
    "            x = torch.mm(x, self.params[i])\n",
    "        return x\n",
    "net = MyListDense()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDictDense(\n",
      "  (params): ParameterDict(\n",
      "      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MyDictDense(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyDictDense, self).__init__()\n",
    "        self.params = nn.ParameterDict({\n",
    "                'linear1': nn.Parameter(torch.randn(4, 4)),\n",
    "                'linear2': nn.Parameter(torch.randn(4, 1))\n",
    "        })\n",
    "        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增\n",
    "\n",
    "    def forward(self, x, choice='linear1'):\n",
    "        return torch.mm(x, self.params[choice])\n",
    "\n",
    "net = MyDictDense()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)\n",
      "tensor([[-0.8783]], grad_fn=<MmBackward>)\n",
      "tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(1, 4)\n",
    "print(net(x, 'linear1'))\n",
    "print(net(x, 'linear2'))\n",
    "print(net(x, 'linear3'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): MyDictDense(\n",
      "    (params): ParameterDict(\n",
      "        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "    )\n",
      "  )\n",
      "  (1): MyListDense(\n",
      "    (params): ParameterList(\n",
      "        (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "    )\n",
      "  )\n",
      ")\n",
      "tensor([[-101.2394]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    MyDictDense(),\n",
    "    MyListDense(),\n",
    ")\n",
    "print(net)\n",
    "print(net(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
